import os
from pathlib import Path
import yaml
import pickle
from dataclasses import dataclass
import functools
import time

import jax
import jax.numpy as jnp
from flax import serialization

import brax
from brax import envs
import dreamerv3
from dreamerv3 import embodied
from dreamerv3.embodied.envs import from_brax, antwall
from dreamerv3.embodied.core import feat_wrappers
from baselines.qdax import environments
from baselines.qdax.core.neuroevolution.networks.networks import MLPDC
from baselines.qdax.baselines.diayn_smerl import DIAYNSMERL, DiaynSmerlConfig
from baselines.qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer
from baselines.qdax.baselines.sac import SacConfig, SAC
from baselines.qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer, do_iteration_no_jit_fn, warmstart_buffer_no_jit
from baselines.qdax.utils.metrics import CSVLogger

import wandb
import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf


@dataclass
class Config:
    """Typed schema for the Hydra configuration consumed by `main`.

    Only the fields below are declared here. `main` additionally reads other
    keys from the loaded configuration (e.g. `batch_size`, `warmup_steps`,
    `replay_buffer_size`, `grad_updates_per_step`) that are not part of this
    dataclass.
    """

    seed: int  # seed for the top-level `jax.random.PRNGKey` in `main`
    algo_name: str  # selects which `get_meta_env_*` factory `main` calls
    num_iterations: int  # number of `do_iteration` calls in the training loop
    log_period: int  # evaluation/logging happens when `i % log_period == 0`
    env_batch_size: int  # passed as `batch_size` when environments are created
    action_repeat: int  # passed as `action_repeat` when environments are created


def get_meta_env_ours(config_hydra, run_path):
    """Build a meta-environment driven by a checkpointed DreamerV3 agent.

    The returned environment takes a goal as its action. On every `step` the
    incoming action is mapped affinely onto the wrapped environment's
    `behavior_descriptor_limits` (-1 maps to the lower limit, +1 to the upper
    limit), handed to the loaded agent as the `goal` observation, and the
    low-level action returned by the agent is applied to the wrapped Brax
    environment.

    Args:
        config_hydra: Adaptation config; `action_repeat` and `env_batch_size`
            are read from it.
        run_path: Directory of the pretraining run. It must contain a `wandb`
            directory holding the run's `files/config.yaml`, and the agent
            checkpoint `checkpoint.ckpt`.

    Returns:
        A `MetaEnv` instance wrapping the created environment.

    Raises:
        OSError: If neither candidate wandb `config.yaml` can be opened.
    """
    run_path = Path(run_path)
    wandb_dir = run_path / "wandb"

    # Read task/feat/backend from the pretraining run's wandb config. The first
    # directory entry is tried first; if it is unusable (missing, not a run
    # directory, unparsable) fall back to the `latest-run` entry. Only the
    # failures this fallback is meant for are caught, so KeyboardInterrupt and
    # programming errors are no longer silently swallowed.
    try:
        config_path = list(wandb_dir.iterdir())[0] / "files" / "config.yaml"
        with open(config_path) as f:
            wandb_config = yaml.safe_load(f)
    except (OSError, IndexError, yaml.YAMLError):
        config_path = wandb_dir / "latest-run" / "files" / "config.yaml"
        with open(config_path) as f:
            wandb_config = yaml.safe_load(f)

    argv = [
        "--task={}".format(wandb_config["task"]["value"]),
        "--feat={}".format(wandb_config["feat"]["value"]),
        "--backend={}".format(wandb_config["backend"]["value"]),
        "--run.from_checkpoint={}".format(str(run_path / "checkpoint.ckpt")),
    ]

    # Create config: defaults, then the "brax" preset, then local overrides,
    # then the flags recovered from the pretraining run.
    config = embodied.Config(dreamerv3.configs["defaults"])
    config = config.update(dreamerv3.configs["brax"])
    config = config.update({
        "logdir": str(run_path),
        "run.train_ratio": 32,
        "run.log_every": 60,  # Seconds
        "batch_size": 16,
    })
    config = embodied.Flags(config).parse(argv=argv)

    # Create logger.
    # NOTE(review): `logger` is never used below; it is still constructed
    # because building the outputs may have side effects (e.g. creating files
    # in `logdir`) — confirm before removing.
    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()
    logger = embodied.Logger(step, [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, "metrics.jsonl"),
        embodied.logger.TensorBoardOutput(logdir),
    ])

    # Create environment
    brax.envs.register_environment("antwall", antwall.AntWall)
    env = brax.envs.create(env_name=config.task + "wall",
                           episode_length=config.episode_length,
                           action_repeat=config_hydra.action_repeat,
                           auto_reset=True,
                           batch_size=config_hydra.env_batch_size,
                           backend=config.backend,
                           debug=True)
    env = feat_wrappers.VelocityWrapper(env, config.task)

    # Create agent. The wrapped "space" environment is only used to obtain the
    # observation/action/feature spaces the agent is built from.
    env_space = from_brax.FromBraxVec(env, obs_key="vector", seed=config.seed, n_envs=config_hydra.env_batch_size)
    env_space = dreamerv3.wrap_env(env_space, config)
    agent = dreamerv3.Agent(env_space.obs_space, env_space.act_space, env_space.feat_space, step, config)
    args = embodied.Config(
        **config.run, logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length)

    # Restore only the agent from the checkpoint.
    checkpoint = embodied.Checkpoint()
    checkpoint.agent = agent
    checkpoint.load(args.from_checkpoint, keys=['agent'])

    def policy(*policy_args):
        """Run the loaded agent's policy in evaluation mode."""
        return agent.policy(*policy_args, mode='eval')

    class MetaEnv(envs.Env):
        """Environment whose actions are goals for the loaded agent."""

        def __init__(self, env):
            self.env = env
            # Recurrent policy state, carried across `step` calls as a Python
            # attribute (so `step` is stateful).
            self._state = None
            self._step_fn = jax.jit(self.env.step)
            self._reset_fn = jax.jit(self.env.reset)
            # Divisor applied to the position entry prepended to observations.
            self._position_normalization = 50.

        def get_obs(self, state, action):
            """Assemble the observation dict passed to the agent's policy."""
            return {
                # Drop the entry prepended by `_augment_obs`.
                "vector": state.obs[..., 1:],
                "is_first": state.info["steps"] == 0,
                "is_terminal": state.done == 1,
                "feat": state.info["feat"],
                "goal": action,
            }

        def reset(self, rng):
            """Reset the wrapped environment and augment its observation."""
            state = self.env.reset(rng)
            return state.replace(obs=self._augment_obs(state))

        def step(self, state, action):
            """Advance one step, interpreting `action` as a goal in [-1, 1]."""
            obs = self.get_obs(state, self._scale_action(action))
            acts, self._state = policy(obs, self._state)
            # The wrapped environment expects its own observation layout, so
            # the prepended entry is removed before stepping.
            next_state = self._step_fn(state.replace(obs=state.obs[..., 1:]), jnp.array(acts["action"]))
            return next_state.replace(
                obs=self._augment_obs(next_state),
                reward=next_state.reward + self._get_reward(state, next_state))

        def _scale_action(self, action):
            """Map `action` affinely onto the behavior-descriptor limits."""
            low = self.env.behavior_descriptor_limits[0]
            high = self.env.behavior_descriptor_limits[1]
            return 0.5 * (high - low) * action + 0.5 * (high + low)

        def _augment_obs(self, state):
            """Prepend the normalized first coordinate of `pipeline_state.q`."""
            # presumably q[..., 0] is the x position — TODO confirm
            position = state.pipeline_state.q[..., 0:1] / self._position_normalization
            return jnp.concatenate([position, state.obs], axis=-1)

        def _get_reward(self, state, next_state):
            """Extra reward added on top of the wrapped reward (currently 0)."""
            return 0.

        @property
        def action_size(self):
            return self.env.feat_size

        @property
        def observation_size(self):
            # One extra entry is prepended by `_augment_obs`.
            return self.env.observation_size + 1

        @property
        def backend(self):
            return self.env.backend

        @property
        def episode_length(self):
            return self.env.episode_length

    return MetaEnv(env)


def get_meta_env_smerl(config_hydra, run_path):
    """Build a meta-environment driven by a pretrained DIAYN+SMERL policy.

    The returned environment takes a goal as its action. On every `step` the
    incoming action is mapped affinely onto the wrapped environment's
    `behavior_descriptor_limits`, passed through the loaded discriminator to
    obtain a latent skill, and the loaded actor (conditioned on observation
    and skill) selects the low-level action applied to the wrapped
    environment.

    Args:
        config_hydra: Adaptation config; `action_repeat` and `env_batch_size`
            are read from it.
        run_path: Directory of the pretraining run. It must contain
            `.hydra/config.yaml`, `actor/actor.pickle` and
            `discriminator/discriminator.pickle`.

    Returns:
        A `MetaEnv` instance wrapping the created environment.
    """
    run_path = Path(run_path)
    # Load the pretraining run's Hydra config (a single parse is enough; the
    # file was previously also read with `yaml.safe_load` and discarded).
    config = OmegaConf.load(run_path / ".hydra" / "config.yaml")

    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment
    env = environments.create(config.task + "wall" + "_" + config.feat,
                              episode_length=config.algo.episode_length,
                              action_repeat=config_hydra.action_repeat,
                              batch_size=config_hydra.env_batch_size,
                              backend=config.algo.backend,)

    # Define config
    smerl_config = DiaynSmerlConfig(
        # SAC config
        batch_size=config.algo.batch_size,
        episode_length=config.algo.episode_length,
        tau=config.algo.soft_tau_update,
        normalize_observations=config.algo.normalize_observations,
        learning_rate=config.algo.learning_rate,
        alpha_init=config.algo.alpha_init,
        discount=config.algo.discount,
        reward_scaling=config.algo.reward_scaling,
        hidden_layer_sizes=config.algo.hidden_layer_sizes,
        fix_alpha=config.algo.fix_alpha,
        # DIAYN config
        skill_type=config.algo.skill_type,
        num_skills=config.algo.num_skills,
        descriptor_full_state=config.algo.descriptor_full_state,
        extrinsic_reward=False,
        beta=1.,
        # SMERL
        reverse=False,
        diversity_reward_scale=config.algo.diversity_reward_scale,
        smerl_target=config.algo.smerl_target,
        smerl_margin=config.algo.smerl_margin,
    )

    # Define an instance of DIAYN
    smerl = DIAYNSMERL(config=smerl_config, action_size=env.action_size)

    # Initialise parameter trees with the right structure; they only serve as
    # templates for `serialization.from_state_dict` below.
    random_key, random_subkey_1, random_subkey_2 = jax.random.split(random_key, 3)
    fake_obs = jnp.zeros((env.observation_size + config.algo.num_skills,))
    fake_goal = jnp.zeros((config.algo.num_skills,))
    fake_actor_params = smerl._policy.init(random_subkey_1, fake_obs)
    fake_discriminator_params = smerl._discriminator.init(random_subkey_2, fake_goal)

    # NOTE(security): `pickle.load` can execute arbitrary code contained in
    # the file; `run_path` must only point at trusted run directories.
    with open(run_path / "actor/actor.pickle", "rb") as params_file:
        actor_state_dict = pickle.load(params_file)
    actor_params = serialization.from_state_dict(fake_actor_params, actor_state_dict)

    with open(run_path / "discriminator/discriminator.pickle", "rb") as params_file:
        discriminator_state_dict = pickle.load(params_file)
    discriminator_params = serialization.from_state_dict(fake_discriminator_params, discriminator_state_dict)

    class MetaEnv(envs.Env):
        """Environment whose actions are goals for the loaded SMERL policy."""

        def __init__(self, env):
            self.env = env
            # Divisor applied to the position entry prepended to observations.
            self._position_normalization = 50.

        def reset(self, rng):
            """Reset the wrapped environment and augment its observation."""
            state = self.env.reset(rng)
            return state.replace(obs=self._augment_obs(state))

        def step(self, state, action):
            """Advance one step, interpreting `action` as a goal in [-1, 1]."""
            latent_skill, _ = smerl._discriminator.apply(discriminator_params, self._scale_action(action))
            # Drop the entry prepended by `_augment_obs` before conditioning
            # the actor and before stepping the wrapped environment.
            raw_obs = state.obs[..., 1:]
            low_level_action, _ = smerl.select_action(
                obs=jnp.concatenate([raw_obs, latent_skill], axis=-1),
                policy_params=actor_params,
                random_key=None,
                deterministic=True,)
            next_state = self.env.step(state.replace(obs=raw_obs), low_level_action)
            return next_state.replace(
                obs=self._augment_obs(next_state),
                reward=next_state.reward + self._get_reward(state, next_state))

        def _scale_action(self, action):
            """Map `action` affinely onto the behavior-descriptor limits."""
            low = self.env.behavior_descriptor_limits[0]
            high = self.env.behavior_descriptor_limits[1]
            return 0.5 * (high - low) * action + 0.5 * (high + low)

        def _augment_obs(self, state):
            """Prepend the normalized first coordinate of `pipeline_state.q`."""
            # presumably q[..., 0] is the x position — TODO confirm
            position = state.pipeline_state.q[..., 0:1] / self._position_normalization
            return jnp.concatenate([position, state.obs], axis=-1)

        def _get_reward(self, state, next_state):
            """Extra reward added on top of the wrapped reward (currently 0)."""
            return 0.

        @property
        def action_size(self):
            return self.env.feat_size

        @property
        def observation_size(self):
            # One extra entry is prepended by `_augment_obs`.
            return self.env.observation_size + 1

        @property
        def backend(self):
            return self.env.backend

        @property
        def episode_length(self):
            return self.env.episode_length

    return MetaEnv(env)


def get_meta_env_smerl_reverse(config_hydra, run_path):
    """Build a meta-environment driven by a pretrained reverse-SMERL policy.

    Identical in construction to the SMERL meta-environment except that the
    `DiaynSmerlConfig` is created with `reverse=True`. On every `step` the
    incoming action is mapped affinely onto the wrapped environment's
    `behavior_descriptor_limits`, passed through the loaded discriminator to
    obtain a latent skill, and the loaded actor selects the low-level action
    applied to the wrapped environment.

    Args:
        config_hydra: Adaptation config; `action_repeat` and `env_batch_size`
            are read from it.
        run_path: Directory of the pretraining run. It must contain
            `.hydra/config.yaml`, `actor/actor.pickle` and
            `discriminator/discriminator.pickle`.

    Returns:
        A `MetaEnv` instance wrapping the created environment.
    """
    run_path = Path(run_path)
    # Load the pretraining run's Hydra config (a single parse is enough; the
    # file was previously also read with `yaml.safe_load` and discarded).
    config = OmegaConf.load(run_path / ".hydra" / "config.yaml")

    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment
    env = environments.create(config.task + "wall" + "_" + config.feat,
                              episode_length=config.algo.episode_length,
                              action_repeat=config_hydra.action_repeat,
                              batch_size=config_hydra.env_batch_size,
                              backend=config.algo.backend,)

    # Define config
    smerl_config = DiaynSmerlConfig(
        # SAC config
        batch_size=config.algo.batch_size,
        episode_length=config.algo.episode_length,
        tau=config.algo.soft_tau_update,
        normalize_observations=config.algo.normalize_observations,
        learning_rate=config.algo.learning_rate,
        alpha_init=config.algo.alpha_init,
        discount=config.algo.discount,
        reward_scaling=config.algo.reward_scaling,
        hidden_layer_sizes=config.algo.hidden_layer_sizes,
        fix_alpha=config.algo.fix_alpha,
        # DIAYN config
        skill_type=config.algo.skill_type,
        num_skills=config.algo.num_skills,
        descriptor_full_state=config.algo.descriptor_full_state,
        extrinsic_reward=False,
        beta=1.,
        # SMERL
        reverse=True,
        diversity_reward_scale=config.algo.diversity_reward_scale,
        smerl_target=config.algo.smerl_target,
        smerl_margin=config.algo.smerl_margin,
    )

    # Define an instance of DIAYN
    smerl = DIAYNSMERL(config=smerl_config, action_size=env.action_size)

    # Initialise parameter trees with the right structure; they only serve as
    # templates for `serialization.from_state_dict` below.
    random_key, random_subkey_1, random_subkey_2 = jax.random.split(random_key, 3)
    fake_obs = jnp.zeros((env.observation_size + config.algo.num_skills,))
    fake_goal = jnp.zeros((config.algo.num_skills,))
    fake_actor_params = smerl._policy.init(random_subkey_1, fake_obs)
    fake_discriminator_params = smerl._discriminator.init(random_subkey_2, fake_goal)

    # NOTE(security): `pickle.load` can execute arbitrary code contained in
    # the file; `run_path` must only point at trusted run directories.
    with open(run_path / "actor/actor.pickle", "rb") as params_file:
        actor_state_dict = pickle.load(params_file)
    actor_params = serialization.from_state_dict(fake_actor_params, actor_state_dict)

    with open(run_path / "discriminator/discriminator.pickle", "rb") as params_file:
        discriminator_state_dict = pickle.load(params_file)
    discriminator_params = serialization.from_state_dict(fake_discriminator_params, discriminator_state_dict)

    class MetaEnv(envs.Env):
        """Environment whose actions are goals for the loaded SMERL policy."""

        def __init__(self, env):
            self.env = env
            # Divisor applied to the position entry prepended to observations.
            self._position_normalization = 50.

        def reset(self, rng):
            """Reset the wrapped environment and augment its observation."""
            state = self.env.reset(rng)
            return state.replace(obs=self._augment_obs(state))

        def step(self, state, action):
            """Advance one step, interpreting `action` as a goal in [-1, 1]."""
            latent_skill, _ = smerl._discriminator.apply(discriminator_params, self._scale_action(action))
            # Drop the entry prepended by `_augment_obs` before conditioning
            # the actor and before stepping the wrapped environment.
            raw_obs = state.obs[..., 1:]
            low_level_action, _ = smerl.select_action(
                obs=jnp.concatenate([raw_obs, latent_skill], axis=-1),
                policy_params=actor_params,
                random_key=None,
                deterministic=True,)
            next_state = self.env.step(state.replace(obs=raw_obs), low_level_action)
            return next_state.replace(
                obs=self._augment_obs(next_state),
                reward=next_state.reward + self._get_reward(state, next_state))

        def _scale_action(self, action):
            """Map `action` affinely onto the behavior-descriptor limits."""
            low = self.env.behavior_descriptor_limits[0]
            high = self.env.behavior_descriptor_limits[1]
            return 0.5 * (high - low) * action + 0.5 * (high + low)

        def _augment_obs(self, state):
            """Prepend the normalized first coordinate of `pipeline_state.q`."""
            # presumably q[..., 0] is the x position — TODO confirm
            position = state.pipeline_state.q[..., 0:1] / self._position_normalization
            return jnp.concatenate([position, state.obs], axis=-1)

        def _get_reward(self, state, next_state):
            """Extra reward added on top of the wrapped reward (currently 0)."""
            return 0.

        @property
        def action_size(self):
            return self.env.feat_size

        @property
        def observation_size(self):
            # One extra entry is prepended by `_augment_obs`.
            return self.env.observation_size + 1

        @property
        def backend(self):
            return self.env.backend

        @property
        def episode_length(self):
            return self.env.episode_length

    return MetaEnv(env)


def get_meta_env_uvfa(config_hydra, run_path):
    """Build a meta-environment driven by a checkpointed UVFA agent.

    The UVFA baseline is loaded through exactly the same pipeline as
    `get_meta_env_ours` (wandb config lookup, DreamerV3 config and agent
    construction, checkpoint restore, identical `MetaEnv` wrapper); only the
    run directory passed by the caller differs. This function therefore
    delegates to `get_meta_env_ours` instead of duplicating its body.

    Args:
        config_hydra: Adaptation config; `action_repeat` and `env_batch_size`
            are read from it.
        run_path: Directory of the UVFA pretraining run. It must contain a
            `wandb` directory holding the run's `files/config.yaml`, and the
            agent checkpoint `checkpoint.ckpt`.

    Returns:
        A `MetaEnv` instance wrapping the created environment.
    """
    return get_meta_env_ours(config_hydra, run_path)


def get_meta_env_dcg_me(config_hydra, run_path):
    """Build a meta-environment driven by a pretrained DCG-ME actor.

    The returned environment takes a descriptor as its action. On every
    `step` the loaded descriptor-conditioned actor (`MLPDC`) receives the
    current observation and the incoming action divided by the first upper
    behavior-descriptor limit, and its output is applied to the wrapped
    environment.

    Args:
        config_hydra: Adaptation config; `action_repeat` and `env_batch_size`
            are read from it.
        run_path: Directory of the pretraining run. It must contain
            `.hydra/config.yaml` and `actor/actor.pickle`.

    Returns:
        A `MetaEnv` instance wrapping the created environment.
    """
    run_path = Path(run_path)
    # Load the pretraining run's Hydra config (a single parse is enough; the
    # file was previously also read with `yaml.safe_load` and discarded).
    config = OmegaConf.load(run_path / ".hydra" / "config.yaml")

    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment
    env = environments.create(config.task + "wall" + "_" + config.feat,
                              episode_length=config.algo.episode_length,
                              action_repeat=config_hydra.action_repeat,
                              batch_size=config_hydra.env_batch_size,
                              backend=config.algo.backend,)

    # Init policy network
    policy_layer_sizes = config.algo.policy_hidden_layer_sizes + (env.action_size,)
    actor_dc_network = MLPDC(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )

    # Initialise a parameter tree with the right structure; it only serves as
    # a template for `serialization.from_state_dict` below.
    random_key, subkey = jax.random.split(random_key)
    fake_obs = jnp.zeros(shape=(env.observation_size,))
    fake_desc = jnp.zeros(shape=(env.behavior_descriptor_length,))
    fake_actor_params = actor_dc_network.init(subkey, fake_obs, fake_desc)

    # NOTE(security): `pickle.load` can execute arbitrary code contained in
    # the file; `run_path` must only point at trusted run directories.
    with open(run_path / "actor/actor.pickle", "rb") as params_file:
        actor_state_dict = pickle.load(params_file)
    actor_params = serialization.from_state_dict(fake_actor_params, actor_state_dict)

    class MetaEnv(envs.Env):
        """Environment whose actions are descriptors for the loaded actor."""

        def __init__(self, env):
            self.env = env
            # Divisor applied to the position entry prepended to observations.
            self._position_normalization = 50.

        def reset(self, rng):
            """Reset the wrapped environment and augment its observation."""
            state = self.env.reset(rng)
            return state.replace(obs=self._augment_obs(state))

        def step(self, state, action):
            """Advance one step, interpreting `action` as a target descriptor."""
            # Drop the entry prepended by `_augment_obs`.
            raw_obs = state.obs[..., 1:]
            # NOTE(review): unlike the other meta-environments this does not
            # call `_scale_action`; the raw action is divided by the first
            # upper descriptor limit — confirm this is the intended scaling.
            descriptor = action / self.env.behavior_descriptor_limits[1][0]
            low_level_action = actor_dc_network.apply(actor_params, raw_obs, descriptor)
            next_state = self.env.step(state.replace(obs=raw_obs), low_level_action)
            return next_state.replace(
                obs=self._augment_obs(next_state),
                reward=next_state.reward + self._get_reward(state, next_state))

        def _scale_action(self, action):
            """Map `action` affinely onto the behavior-descriptor limits."""
            low = self.env.behavior_descriptor_limits[0]
            high = self.env.behavior_descriptor_limits[1]
            return 0.5 * (high - low) * action + 0.5 * (high + low)

        def _augment_obs(self, state):
            """Prepend the normalized first coordinate of `pipeline_state.q`."""
            # presumably q[..., 0] is the x position — TODO confirm
            position = state.pipeline_state.q[..., 0:1] / self._position_normalization
            return jnp.concatenate([position, state.obs], axis=-1)

        def _get_reward(self, state, next_state):
            """Extra reward added on top of the wrapped reward (currently 0)."""
            return 0.

        @property
        def action_size(self):
            return self.env.feat_size

        @property
        def observation_size(self):
            # One extra entry is prepended by `_augment_obs`.
            return self.env.observation_size + 1

        @property
        def backend(self):
            return self.env.backend

        @property
        def episode_length(self):
            return self.env.episode_length

    return MetaEnv(env)


@hydra.main(version_base="1.2", config_path="configs/", config_name="adaptation_wall")
def main(config: Config) -> None:
    """Train a SAC controller on top of a pretrained skill policy.

    Builds the meta-environment selected by `config.algo_name`, warm-starts a
    replay buffer, runs `config.num_iterations` SAC iterations, evaluates and
    logs to `./log.csv` every `config.log_period` iterations, and finally
    pickles the trained actor parameters to `./actor/actor.pickle`.

    Args:
        config: Hydra configuration. Besides the fields declared on `Config`
            it must provide the SAC hyper-parameters read below
            (`batch_size`, `soft_tau_update`, `normalize_observations`,
            `learning_rate`, `alpha_init`, `discount`, `reward_scaling`,
            `hidden_layer_sizes`, `fix_alpha`) as well as
            `replay_buffer_size`, `warmup_steps` and `grad_updates_per_step`.

    Raises:
        NotImplementedError: If `config.algo_name` is not a known algorithm.
    """
    wandb.init(
        config=OmegaConf.to_container(config, resolve=True),
        project="Dreamer-GC",
        entity='xxx',
        name="adaptation_wall",
    )

    # `exist_ok=True` keeps a re-run in the same working directory from
    # crashing with FileExistsError before any training has happened.
    os.makedirs("./actor/", exist_ok=True)

    # Meta-environment factory and pretraining run directory per algorithm.
    meta_env_sources = {
        "ours": (get_meta_env_ours, "/project/input/ours"),
        "smerl": (get_meta_env_smerl, "/project/input/smerl"),
        "smerl_reverse": (get_meta_env_smerl_reverse, "/project/input/smerl_reverse"),
        "uvfa": (get_meta_env_uvfa, "/project/input/uvfa"),
        "dcg_me": (get_meta_env_dcg_me, "/project/input/dcg_me"),
    }
    if config.algo_name not in meta_env_sources:
        raise NotImplementedError(f"Unknown algo_name: {config.algo_name!r}")
    make_meta_env, run_path = meta_env_sources[config.algo_name]
    env = make_meta_env(config, run_path)

    # env = environments.create("ant_velocity",
    #                           episode_length=1000,
    #                           batch_size=256,
    #                           backend="spring",)

    reset_fn = jax.jit(env.reset)

    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init SAC config
    sac_config = SacConfig(
        batch_size=config.batch_size,
        episode_length=env.episode_length,
        tau=config.soft_tau_update,
        normalize_observations=config.normalize_observations,
        learning_rate=config.learning_rate,
        alpha_init=config.alpha_init,
        discount=config.discount,
        reward_scaling=config.reward_scaling,
        hidden_layer_sizes=config.hidden_layer_sizes,
        fix_alpha=config.fix_alpha,
    )

    # Init SAC
    sac = SAC(config=sac_config, action_size=env.action_size)
    random_key, subkey = jax.random.split(random_key)
    training_state = sac.init(subkey, env.action_size, env.observation_size)

    # "ours" and "uvfa" use the non-jitted code paths.
    # presumably because their `MetaEnv.step` calls the DreamerV3 policy and
    # keeps its recurrent state as a Python attribute — TODO confirm
    if config.algo_name == "ours" or config.algo_name == "uvfa":
        play_step_impl = sac.play_step_no_jit_fn
        eval_policy_impl = sac.eval_policy_no_jit_fn
        warmstart_impl = warmstart_buffer_no_jit
        do_iteration_impl = do_iteration_no_jit_fn
    else:
        play_step_impl = sac.play_step_fn
        eval_policy_impl = sac.eval_policy_fn
        warmstart_impl = warmstart_buffer
        do_iteration_impl = do_iteration_fn

    # Play step functions: stochastic for data collection (training and
    # warm-start), deterministic for evaluation.
    play_step = functools.partial(
        play_step_impl,
        env=env,
        deterministic=False,
    )
    play_eval_step = functools.partial(
        play_step_impl,
        env=env,
        deterministic=True,
    )
    eval_policy = functools.partial(
        eval_policy_impl,
        play_step_fn=play_eval_step,
    )

    # Init replay buffer. The meta-environment's action is a descriptor, so
    # the descriptor dimension equals the action dimension.
    dummy_transition = QDTransition.init_dummy(
        observation_dim=env.observation_size,
        action_dim=env.action_size,
        descriptor_dim=env.action_size,
    )
    replay_buffer = ReplayBuffer.init(
        buffer_size=config.replay_buffer_size, transition=dummy_transition
    )

    # Warm-start the replay buffer from a fresh environment state.
    random_key, subkey = jax.random.split(random_key)
    env_state = env.reset(subkey)
    replay_buffer, _, training_state = warmstart_impl(
        replay_buffer=replay_buffer,
        training_state=training_state,
        env_state=env_state,
        num_warmstart_steps=config.warmup_steps,
        env_batch_size=config.env_batch_size,
        play_step_fn=play_step,
    )
    do_iteration = functools.partial(
        do_iteration_impl,
        env_batch_size=config.env_batch_size,
        grad_updates_per_step=config.grad_updates_per_step,
        play_step_fn=play_step,
        update_fn=sac.update,
    )

    csv_header = [
        "iteration",
        "episode_score",
        "episode_length",
        "actor_loss",
        "critic_loss",
        "alpha_loss",
        "obs_std",
        "obs_mean",
    ]
    csv_logger = CSVLogger(
        "./log.csv",
        header=csv_header
    )

    # Iterations
    for i in range(config.num_iterations):
        training_state, env_state, replay_buffer, metrics = do_iteration(
            training_state=training_state,
            env_state=env_state,
            replay_buffer=replay_buffer,
        )

        if i % config.log_period == 0:
            # Evaluate from a freshly reset state so training rollouts are
            # not disturbed.
            random_key, subkey = jax.random.split(random_key)
            env_state_eval = reset_fn(subkey)
            true_return, _, episode_length_mean = eval_policy(training_state, env_state_eval)

            # Average the iteration's training metrics over their leading axis.
            metrics = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), metrics)
            metrics["iteration"] = i
            metrics["episode_score"] = true_return
            metrics["episode_length"] = episode_length_mean
            csv_logger.log(metrics)

    # Actor
    state_dict = serialization.to_state_dict(training_state.policy_params)
    with open("./actor/actor.pickle", "wb") as params_file:
        pickle.dump(state_dict, params_file)


if __name__ == "__main__":
    # Register the `Config` dataclass under the name "main" in Hydra's
    # ConfigStore before handing control to the Hydra-decorated entry point.
    ConfigStore.instance().store(name="main", node=Config)
    main()
